#  Copyright (c) ZenML GmbH 2022. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at:
#
#       https://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
#  or implied. See the License for the specific language governing
#  permissions and limitations under the License.
"""Kubernetes pod settings."""

from typing import Any, Dict, List, Optional

from pydantic import field_validator

from zenml.config.base_settings import BaseSettings
from zenml.integrations.kubernetes import serialization_utils
from zenml.logger import get_logger

logger = get_logger(__name__)


# Stringified errors for which `warn_if_invalid_model_data` already logged a
# warning. Used to avoid logging the same warning message more than once.
_pod_settings_logged_warnings: List[str] = []


def warn_if_invalid_model_data(data: Any, class_name: str) -> None:
    """Logs a warning if dict data does not match a Kubernetes model class.

    Values that are not dictionaries are ignored. Dictionaries are passed
    to `serialization_utils.deserialize_kubernetes_model`; if that call
    raises a `KeyError`, a warning is logged unless a warning with the same
    error message was already logged before. Other exceptions are not
    handled here.

    Args:
        data: The data to validate.
        class_name: Name of the class of the model.
    """
    warning_template = (
        "Invalid data for Kubernetes model class `%s`: %s. "
        "Hint: Kubernetes expects attribute names in CamelCase, not "
        "snake_case."
    )

    if isinstance(data, dict):
        try:
            serialization_utils.deserialize_kubernetes_model(data, class_name)
        except KeyError as error:
            error_message = str(error)
            if error_message in _pod_settings_logged_warnings:
                # This exact problem was reported before, stay quiet.
                return

            # Remember the message first so it is only ever logged once.
            _pod_settings_logged_warnings.append(error_message)
            logger.warning(warning_template, class_name, error)


def _serialize_model_or_warn(value: Any, model_class: type) -> Any:
    """Serializes a Kubernetes model object or checks its raw data.

    Args:
        value: A Kubernetes model object or any other (raw) value.
        model_class: The Kubernetes model class the value is checked
            against.

    Returns:
        The serialized model if the value is an instance of the model
        class, the unmodified value otherwise.
    """
    if isinstance(value, model_class):
        return serialization_utils.serialize_kubernetes_model(value)

    # Raw values are passed through unchanged. Dictionaries that do not
    # match the model class only trigger a warning.
    warn_if_invalid_model_data(value, model_class.__name__)
    return value


def _serialize_model_list_or_warn(value: Any, model_class: type) -> Any:
    """Serializes all Kubernetes model objects of an iterable.

    Args:
        value: The value passed for a list field.
        model_class: The Kubernetes model class of the list elements.

    Returns:
        A list in which each model object was replaced by its serialized
        form, or the unmodified value if it is not iterable.
    """
    try:
        elements = iter(value)
    except TypeError:
        # Non-iterable input (e.g. `None`) is returned as is so that the
        # pydantic field validation rejects it with a regular validation
        # error instead of a `TypeError` escaping from this validator.
        return value

    return [
        _serialize_model_or_warn(element, model_class) for element in elements
    ]


class KubernetesPodSettings(BaseSettings):
    """Kubernetes Pod settings.

    Several attributes accept either plain dictionaries or the matching
    Kubernetes client model objects. Model objects are converted to
    dictionaries before the field validation runs.

    Attributes:
        node_selectors: Node selectors to apply to the pod.
        affinity: Affinity to apply to the pod.
        tolerations: Tolerations to apply to the pod.
        resources: Resource requests and limits for the pod.
        annotations: Annotations to apply to the pod metadata.
        volumes: Volumes to mount in the pod.
        volume_mounts: Volume mounts to apply to the pod containers.
        host_ipc: Whether to enable host IPC for the pod.
        scheduler_name: The name of the scheduler to use for the pod.
        image_pull_secrets: Image pull secrets to use for the pod.
        labels: Labels to apply to the pod.
        env: Environment variables to apply to the container.
        env_from: Environment variable sources (`V1EnvFromSource`) to apply
            to the container.
        container_security_context: Security context settings to apply to all
            containers in the pod. This allows specifying container-level security
            attributes such as runAsUser, runAsNonRoot, allowPrivilegeEscalation,
            etc. Note: This is different from pod-level security context (which can
            be set via additional_pod_spec_args) as some admission policies require
            container-level security settings.
        additional_pod_spec_args: Additional arguments to pass to the pod. These
            will be applied to the pod spec.
    """

    node_selectors: Dict[str, str] = {}
    affinity: Dict[str, Any] = {}
    tolerations: List[Dict[str, Any]] = []
    resources: Dict[str, Dict[str, str]] = {}
    annotations: Dict[str, str] = {}
    volumes: List[Dict[str, Any]] = []
    volume_mounts: List[Dict[str, Any]] = []
    host_ipc: bool = False
    scheduler_name: Optional[str] = None
    image_pull_secrets: List[str] = []
    labels: Dict[str, str] = {}
    env: List[Dict[str, Any]] = []
    env_from: List[Dict[str, Any]] = []
    container_security_context: Dict[str, Any] = {}
    additional_pod_spec_args: Dict[str, Any] = {}

    @field_validator("volumes", mode="before")
    @classmethod
    def _convert_volumes(cls, value: Any) -> Any:
        """Converts Kubernetes volumes to dicts.

        Args:
            value: The volumes list.

        Returns:
            The converted volumes.
        """
        from kubernetes.client.models import V1Volume

        return _serialize_model_list_or_warn(value, V1Volume)

    @field_validator("volume_mounts", mode="before")
    @classmethod
    def _convert_volume_mounts(cls, value: Any) -> Any:
        """Converts Kubernetes volume mounts to dicts.

        Args:
            value: The volume mounts list.

        Returns:
            The converted volume mounts.
        """
        from kubernetes.client.models import V1VolumeMount

        return _serialize_model_list_or_warn(value, V1VolumeMount)

    @field_validator("affinity", mode="before")
    @classmethod
    def _convert_affinity(cls, value: Any) -> Any:
        """Converts Kubernetes affinity to a dict.

        Args:
            value: The affinity value.

        Returns:
            The converted value.
        """
        from kubernetes.client.models import V1Affinity

        return _serialize_model_or_warn(value, V1Affinity)

    @field_validator("tolerations", mode="before")
    @classmethod
    def _convert_tolerations(cls, value: Any) -> Any:
        """Converts Kubernetes tolerations to dicts.

        Args:
            value: The tolerations list.

        Returns:
            The converted tolerations.
        """
        from kubernetes.client.models import V1Toleration

        return _serialize_model_list_or_warn(value, V1Toleration)

    @field_validator("resources", mode="before")
    @classmethod
    def _convert_resources(cls, value: Any) -> Any:
        """Converts Kubernetes resource requirements to a dict.

        Args:
            value: The resource value.

        Returns:
            The converted value.
        """
        from kubernetes.client.models import V1ResourceRequirements

        return _serialize_model_or_warn(value, V1ResourceRequirements)

    @field_validator("env", mode="before")
    @classmethod
    def _convert_env(cls, value: Any) -> Any:
        """Converts Kubernetes EnvVar objects to dicts.

        Args:
            value: The env value.

        Returns:
            The converted value.
        """
        from kubernetes.client.models import V1EnvVar

        return _serialize_model_list_or_warn(value, V1EnvVar)

    @field_validator("env_from", mode="before")
    @classmethod
    def _convert_env_from(cls, value: Any) -> Any:
        """Converts Kubernetes EnvFromSource objects to dicts.

        Args:
            value: The env from value.

        Returns:
            The converted value.
        """
        from kubernetes.client.models import V1EnvFromSource

        return _serialize_model_list_or_warn(value, V1EnvFromSource)

    @field_validator("container_security_context", mode="before")
    @classmethod
    def _convert_container_security_context(cls, value: Any) -> Any:
        """Converts Kubernetes SecurityContext to a dict.

        Args:
            value: The container security context value.

        Returns:
            The converted value.
        """
        from kubernetes.client.models import V1SecurityContext

        return _serialize_model_or_warn(value, V1SecurityContext)
